package cn.atcoder.air.transport;

import cn.atcoder.air.client.ClientChannelHandler;
import cn.atcoder.air.client.MessageFuture;
import cn.atcoder.air.exception.ClientTimeoutException;
import cn.atcoder.air.exception.MessageException;
import cn.atcoder.air.msg.*;
import io.netty.bootstrap.Bootstrap;
import io.netty.channel.Channel;
import io.netty.channel.ChannelFuture;
import io.netty.channel.ChannelFutureListener;
import io.netty.channel.ChannelInitializer;
import io.netty.channel.ChannelOption;
import io.netty.channel.socket.SocketChannel;
import io.netty.channel.socket.nio.NioSocketChannel;
import io.netty.handler.codec.serialization.ClassResolvers;
import io.netty.handler.codec.serialization.ObjectDecoder;
import io.netty.handler.codec.serialization.ObjectEncoder;
import io.netty.handler.timeout.IdleStateHandler;

import java.util.Objects;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicLong;

/**
 * @author yangjunda1
 * @description
 * @date 5/9/19 5:21 PM
 */
public class TcpClientTransport extends AbstractClientTransport {

    private final ConcurrentHashMap<Long, MessageFuture> futureMap = new ConcurrentHashMap<>();

    private AtomicLong requestId = new AtomicLong(0);

    public TcpClientTransport(String host, int port) {
        super(host, port);
    }

    @Override
    protected ResponseMessage send(RequestMessage msg, int timeout) {
        long msgId = generateRequestId();
        try {
            super.currentRequests.incrementAndGet();
            msg.setMessageId(msgId);

            MessageFuture<ResponseMessage> future = doSendFuture(msg, timeout);
            return future.get(timeout, TimeUnit.MILLISECONDS);
        } catch (InterruptedException e) {
            throw new MessageException("Client request thread interrupted");
        } catch (ClientTimeoutException e) {
            try {
                futureMap.remove(msgId);
            } catch (Exception e1) {
                LOGGER.error(e1.getMessage(), e1);
            }
            throw e;
        } finally {
            super.currentRequests.decrementAndGet();
        }
    }

    /**
     * handle the Response
     */
    public void receiveResponse(ResponseMessage msg) {
        if (LOGGER.isTraceEnabled()) {
            LOGGER.trace("receiveResponse..{}", msg);
        }
        Long msgId = msg.getMessageId();
        MessageFuture future = futureMap.get(msgId);
        if (future == null) {
            LOGGER.warn("Not found future which msgId is {} when receive response. May be " +
                    "this future have been removed because of timeout", msgId);
        } else {
            future.setSuccess(msg);
            futureMap.remove(msgId);
        }
    }


    @Override
    void start0() {
        bootstrap = new Bootstrap();
        bootstrap.channel(NioSocketChannel.class);
        bootstrap.option(ChannelOption.SO_KEEPALIVE, true);
        bootstrap.group(workGroup);
        bootstrap.remoteAddress(host, port);
        bootstrap.handler(new ChannelInitializer<SocketChannel>() {
            @Override
            protected void initChannel(SocketChannel socketChannel) throws Exception {
                socketChannel.pipeline().addLast(new IdleStateHandler(20, 10, 0));
                socketChannel.pipeline().addLast(new ObjectEncoder());
                socketChannel.pipeline().addLast(new ObjectDecoder(ClassResolvers.cacheDisabled(null)));
                socketChannel.pipeline().addLast(new ClientChannelHandler());
                ClientTransportFactory.build(TcpClientTransport.this);
            }
        });
        doConnect();
    }

    @Override
    public void doConnect() {
        if (Objects.nonNull(channel) && channel.isActive()) {
            return;
        }
        ChannelFuture future = bootstrap.connect(host, port);
        future.addListener((ChannelFutureListener) channelFuture -> {
            if (channelFuture.isSuccess()) {
                channel = future.channel();
                LOGGER.info("Connect to provider:{} success! The connection is {} -> {}", getRemoteAddress(), channel.localAddress(), getRemoteAddress());
            } else {
                LOGGER.info("连接到服务器失败，10秒后尝试重连");
                channelFuture.channel().eventLoop().schedule(this::doConnect, 10, TimeUnit.SECONDS);
            }
        });
    }

    @Override
    public void shutdown() {

    }

    @Override
    public String getRemoteAddress() {
        return host + ":" + port;
    }

    @Override
    public boolean isOpen() {
        return Objects.nonNull(channel) && channel.isActive();
    }

    /**
     * different FutureMap for different Request msg type
     */
    private void addFuture(BaseMessage msg, MessageFuture msgFuture) {
        MessageType msgType = msg.getMessageType();
        Long msgId = msg.getMessageId();
        if (msgType == MessageType.CALLBACK_REQUEST_MSG
                || msgType == MessageType.REGISTER_REQUEST_MSG
                || msgType == MessageType.HEARTBEAT_REQUEST_MSG) {
            this.futureMap.put(msgId, msgFuture);

        } else {
            LOGGER.error("cannot handle Future for this Msg:{}", msg);
        }
    }

    private MessageFuture<ResponseMessage> doSendFuture(RequestMessage msg, int timeout) {

        MessageFuture<ResponseMessage> future = new MessageFuture<>(msg.getHeader(), timeout, channel);
        if (msg instanceof CallbackRequestMessage) {
            future.setInvocationBody(((CallbackRequestMessage) msg).getInvocationBody());
        }
        channel.writeAndFlush(msg);
        // 置为已发送
        future.setSentTime(System.currentTimeMillis());
        this.addFuture(msg, future);
        return future;
    }


    private long generateRequestId() {
        return requestId.getAndIncrement();
    }
}
